import torch
from torch import nn
import torch.nn.functional as F

from cbml_benchmark.losses.registry import LOSS

def binarize(T, nb_classes, device=None):
    """Return a float32 one-hot matrix of shape ``(len(T), nb_classes)``.

    Row ``i`` has a 1 in column ``T[i]``. Labels outside ``[0, nb_classes)``
    yield an all-zero row, matching the previous ``label_binarize`` behaviour.
    Unlike ``label_binarize``, two classes still produce two columns:
    sklearn collapses the binary case to a single column, which broke the
    ``(batch, nb_classes)`` masks built from this matrix.

    Args:
        T: label tensor (any device); flattened and cast to int64.
        nb_classes: number of output columns.
        device: target device. ``None`` keeps the historical behaviour of
            returning a CUDA tensor, falling back to ``T``'s own device when
            CUDA is unavailable.

    Returns:
        ``torch.float32`` tensor of shape ``(T.numel(), nb_classes)``.
    """
    labels = T.reshape(-1).long()
    in_range = (labels >= 0) & (labels < nb_classes)
    # Clamp so one_hot never receives an invalid index, then blank the rows
    # whose original label was out of range.
    one_hot = F.one_hot(labels.clamp(0, max(nb_classes - 1, 0)), num_classes=nb_classes)
    one_hot = one_hot * in_range.unsqueeze(1)
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else labels.device
    return one_hot.to(device=device, dtype=torch.float32)

def l2_norm(input):
    """Scale each row of a tensor to unit Euclidean length.

    The squared norm of every row is offset by 1e-12 before the square root,
    so an all-zero row maps to zeros instead of nan.

    Args:
        input: tensor whose dim 1 is the feature axis. The per-row norm is
            reshaped to ``(-1, 1)`` and broadcast back, so in practice this
            expects a ``(rows, features)`` tensor.

    Returns:
        Tensor with the same shape as ``input``.
    """
    original_shape = input.size()
    squared_norms = (input ** 2).sum(dim=1) + 1e-12
    row_norms = squared_norms.sqrt().view(-1, 1)
    normalized = input / row_norms.expand_as(input)
    return normalized.view(original_shape)


@LOSS.register('proxy_anchor_loss')
class Proxy_Anchor(torch.nn.Module):
    """Proxy Anchor loss with one learnable proxy vector per class.

    For every proxy, similarities to the batch samples are aggregated over
    the batch axis. Positive samples are pulled toward their class proxy and
    negatives are pushed away from every other proxy.

    Args:
        cfg: config node; ``cfg.LOSSES.PROXY_ANCHOR_LOSS.nb_classes`` and
            ``cfg.LOSSES.PROXY_ANCHOR_LOSS.sz_embed`` give the number of
            proxies and the embedding width.
        mrg: margin subtracted from positive / added to negative cosine
            similarities inside the exponentials.
        alpha: scale applied inside the exponentials.
    """

    def __init__(self, cfg, mrg=0.1, alpha=32):
        super().__init__()
        loss_cfg = cfg.LOSSES.PROXY_ANCHOR_LOSS
        self.nb_classes = loss_cfg.nb_classes
        self.sz_embed = loss_cfg.sz_embed
        self.mrg = mrg
        self.alpha = alpha

        # Proxy Anchor initialization: a (nb_classes, sz_embed) matrix.
        # NOTE(review): the proxies are created directly on the current CUDA
        # device, so constructing this module requires a GPU.
        self.proxies = torch.nn.Parameter(
            torch.randn(self.nb_classes, self.sz_embed).cuda()
        )
        nn.init.kaiming_normal_(self.proxies, mode='fan_out')

    def forward(self, X, T):
        """Compute the loss for one batch.

        Args:
            X: embeddings, ``(batch, sz_embed)``; they are L2-normalized
                here, so raw network outputs are fine.
            T: class labels for the rows of ``X``; presumably integers in
                ``[0, nb_classes)`` (one-hot encoded by ``binarize``).

        Returns:
            Scalar tensor: the positive term averaged over proxies that have
            at least one positive sample in the batch, plus the negative term
            averaged over all ``nb_classes`` proxies.
        """
        # Cosine similarity between every sample and every proxy:
        # (batch, nb_classes). Each operand is normalized exactly once.
        cos = F.linear(l2_norm(X), l2_norm(self.proxies))

        # Positive mask (sample i belongs to proxy j) and its complement,
        # placed on the same device as the similarity matrix.
        P_one_hot = binarize(T=T, nb_classes=self.nb_classes).to(cos.device)
        N_one_hot = 1 - P_one_hot

        pos_exp = torch.exp(-self.alpha * (cos - self.mrg))
        neg_exp = torch.exp(self.alpha * (cos + self.mrg))

        # Number of proxies with at least one positive sample in the batch.
        num_valid_proxies = int((P_one_hot.sum(dim=0) != 0).sum())

        # Per-proxy sums over the batch of the masked exponentials.
        P_sim_sum = torch.where(P_one_hot == 1, pos_exp, torch.zeros_like(pos_exp)).sum(dim=0)
        N_sim_sum = torch.where(N_one_hot == 1, neg_exp, torch.zeros_like(neg_exp)).sum(dim=0)

        # Proxies without positives have P_sim_sum == 0 and contribute
        # log1p(0) == 0, so summing over all proxies equals summing over the
        # valid ones. max(..., 1) avoids 0/0 = nan when no label in the batch
        # maps to any proxy.
        pos_term = torch.log1p(P_sim_sum).sum() / max(num_valid_proxies, 1)
        neg_term = torch.log1p(N_sim_sum).sum() / self.nb_classes
        return pos_term + neg_term